import numpy as np
from scipy import stats
from scipy import constants
from scipy.integrate import quad, simpson, cumulative_trapezoid
try:
    # ``cumtrapz`` is a deprecated alias that SciPy 1.14 removed; keep the
    # name importable from this module for backward compatibility.
    from scipy.integrate import cumtrapz
except ImportError:
    cumtrapz = cumulative_trapezoid
import pymc as pm
import matplotlib.pyplot as plt
import seaborn as sns

import pytensor.tensor as pt

import arviz as az

import xarray as xr


# Physical constants in SI units, taken from scipy.constants.
h_bar, k_B, elementary_charge = (
    constants.hbar,  # reduced Planck constant
    constants.Boltzmann,  # Boltzmann constant
    constants.elementary_charge,  # elementary charge
)



    
def hazard(I, critical_current, f_j0, Temp):
    """
    Switching (escape) rate of the junction as a function of bias current.

    Arrhenius form: an attempt frequency times a Boltzmann factor for the
    potential barrier,

        eta     = I / critical_current
        E_J     = hbar * critical_current / (2e)
        Delta_U = 2 E_J * (sqrt(1 - eta^2) - eta * arccos(eta))
        rate    = f_j0 * (1 - eta^2)^(1/4) * exp(-Delta_U / (k_B * Temp))

    Parameters
    ----------
    I : array-like
        Bias current points.
    critical_current : float
        Critical current of the junction.
    f_j0 : float
        Zero-temperature plasma frequency.
    Temp : float
        Temperature of the system.

    Returns
    -------
    array-like
        Switching rate at each current in ``I``. Currents above
        ``critical_current`` lie outside the model's domain and produce NaN.
    """
    reduced_bias = I / critical_current
    one_minus_sq = 1 - reduced_bias**2

    josephson_energy = h_bar * critical_current / (2 * elementary_charge)
    barrier = 2 * josephson_energy * (
        np.sqrt(one_minus_sq) - reduced_bias * np.arccos(reduced_bias)
    )
    attempt_freq = f_j0 * one_minus_sq ** (1 / 4)

    return attempt_freq * np.exp(-barrier / (Temp * k_B))

def failure_pdf(I, ramp_speed, critical_current, f_j0, Temp):
    """
    Probability density of the switching current for a linearly ramped bias.

    ``hazard(I)`` is a rate per unit time. With the bias ramped at
    ``ramp_speed`` (dI/dt) it becomes ``hazard(I) / ramp_speed`` per unit
    current, and the density over current is

        P(I) = hazard(I)/ramp_speed * exp(-int_{I[0]}^{I} hazard(I')/ramp_speed dI')

    The integral is evaluated with the cumulative trapezoidal rule on the
    grid ``I``, starting from zero at ``I[0]``. Each output value therefore
    depends on the grid points before it, not only on its own current.

    Parameters
    ----------
    I : array-like
        Current grid; the integral runs along its last axis (presumably a
        1-D ascending grid).
    ramp_speed : float
        Ramp rate of the bias current (dI/dt).
    critical_current : float
        Critical current of the junction.
    f_j0 : float
        Zero-temperature plasma frequency.
    Temp : float
        Temperature of the system.

    Returns
    -------
    numpy.ndarray
        Probability density at each point of ``I``.
    """
    # Per-time hazard -> rate per unit current. The ramp speed has to enter
    # both the prefactor and the survival exponent; multiplying the prefactor
    # by it while leaving it out of the exponent is not a valid density.
    rate_per_current = hazard(I, critical_current, f_j0, Temp) / ramp_speed
    survival = np.exp(-cumulative_trapezoid(rate_per_current, I, initial=0))
    return rate_per_current * survival

def failure_cdf(I, ramp_speed, critical_current, f_j0, Temp):
    """
    Cumulative distribution of the switching current.

    Obtained by integrating ``failure_pdf`` over the grid ``I`` with the
    cumulative trapezoidal rule, starting from zero at ``I[0]``. A NaN in the
    density propagates to every later entry of the result.

    Parameters
    ----------
    I : array-like
        Current grid (presumably 1-D and ascending).
    ramp_speed : float
        Ramp rate of the bias current, forwarded to ``failure_pdf``.
    critical_current : float
        Critical current of the junction.
    f_j0 : float
        Zero-temperature plasma frequency.
    Temp : float
        Temperature of the system.

    Returns
    -------
    numpy.ndarray
        Cumulative probability at each point of ``I``.
    """
    density = failure_pdf(I, ramp_speed, critical_current, f_j0, Temp)
    # cumulative_trapezoid replaces the cumtrapz alias removed in SciPy 1.14.
    return cumulative_trapezoid(density, I, initial=0)
    
    
class FailureDist(stats.rv_continuous):
    """
    SciPy continuous distribution of Josephson-junction switching currents.

    Density and CDF evaluations are delegated to ``failure_pdf`` and
    ``failure_cdf``, which integrate numerically over the whole array of
    points passed in, so evaluate them on an ordered grid.
    """

    def __init__(self, ramp_speed, critical_current, f_j0, Temp, *args, **kwargs):
        # Support starts at zero; the upper bound keeps SciPy's default (+inf).
        super().__init__(*args, a=0, **kwargs)
        # Model parameters forwarded to the module-level helpers.
        self.ramp_speed, self.critical_current = ramp_speed, critical_current
        self.f_j0, self.Temp = f_j0, Temp

    def _pdf(self, I):
        params = (self.ramp_speed, self.critical_current, self.f_j0, self.Temp)
        return failure_pdf(I, *params)

    def _cdf(self, I):
        params = (self.ramp_speed, self.critical_current, self.f_j0, self.Temp)
        return failure_cdf(I, *params)
    
    

def data_exporter(switching_data, critical_current, sample_size, f_j0, Temp, path):
    """
    Write simulated switching data and its parameters to a NetCDF file.

    Every input is stored as a coordinate of an otherwise empty
    ``xarray.Dataset``. The file is written with ``mode='w'``, so an existing
    file at ``path`` is overwritten.

    Parameters
    ----------
    switching_data : array-like
        Simulated switching values, stored under ``'switching_times'``.
    critical_current, sample_size, f_j0, Temp
        Simulation parameters (presumably scalars), stored under coordinates
        of the same names.
    path : str or path-like
        Destination NetCDF file.
    """
    coords = {
        'switching_times': switching_data,
        'critical_current': critical_current,
        'sample_size': sample_size,
        'f_j0': f_j0,
        'Temp': Temp,
    }
    dataset = xr.Dataset(coords=coords)
    dataset.to_netcdf(path, mode='w')

#def prob_short(I,critical_current,ramp,capacitance,Temp):
    
    
def hazard_new(I, critical_current, f_j0, Temp, U_0):
    """
    Switching rate with an adjustable potential-barrier scale.

    Same Arrhenius form as ``hazard`` but with the barrier multiplied by
    ``U_0``:

        eta     = I / critical_current
        E_J     = hbar * critical_current / (2e)
        Delta_U = U_0 * 2 E_J * (sqrt(1 - eta^2) - eta * arccos(eta))
        rate    = f_j0 * (1 - eta^2)^(1/4) * exp(-Delta_U / (k_B * Temp))

    Parameters
    ----------
    I : array-like
        Bias current points.
    critical_current : float
        Critical current of the junction.
    f_j0 : float
        Zero-temperature plasma frequency.
    Temp : float
        Temperature of the system.
    U_0 : float
        Scaling factor for the potential barrier.

    Returns
    -------
    array-like
        Switching rate at each current in ``I``. Currents above
        ``critical_current`` lie outside the model's domain and produce NaN.
    """
    reduced = I / critical_current
    coupling_energy = h_bar * critical_current / (2 * elementary_charge)

    barrier = U_0 * 2 * coupling_energy * (
        np.sqrt(1 - reduced**2) - reduced * np.arccos(reduced)
    )
    attempt_rate = f_j0 * (1 - reduced**2) ** (1 / 4)
    boltzmann_factor = np.exp(-barrier / (Temp * k_B))

    return attempt_rate * boltzmann_factor

def failure_pdf_new(I, ramp_speed, critical_current, f_j0, Temp, U_0):
    """
    Numerically normalised switching-current density (barrier-scaled model).

    The unnormalised density is

        P(I) = hazard_new(I)/ramp_speed * exp(-int_{I[0]}^{I} hazard_new(I')/ramp_speed dI')

    with the integral taken by the cumulative trapezoidal rule on ``I``. The
    result is then divided by its Riemann sum ``nansum(P) * dx``, where
    ``dx = (max(I) - min(I)) / (len(I) - 1)``. That spacing is exact only for
    an evenly spaced grid.

    Parameters
    ----------
    I : array-like
        1-D current grid, assumed ascending and evenly spaced.
    ramp_speed : float
        Ramp rate of the bias current (dI/dt).
    critical_current : float
        Critical current of the junction.
    f_j0 : float
        Zero-temperature plasma frequency.
    Temp : float
        Temperature of the system.
    U_0 : float
        Scaling factor for the potential barrier.

    Returns
    -------
    numpy.ndarray
        Normalised density at each point of ``I``. NaN entries (e.g. above
        the critical current) are kept as NaN.
    """
    I = np.asarray(I)
    rate_per_current = hazard_new(I, critical_current, f_j0, Temp, U_0) / ramp_speed
    survival = np.exp(-cumulative_trapezoid(rate_per_current, I, initial=0))
    pdf = rate_per_current * survival

    # N evenly spaced points spanning [min, max] are (max - min) / (N - 1)
    # apart; dividing by N under-estimated the normalisation constant.
    dx = (I.max() - I.min()) / (I.size - 1)
    # NaN entries are excluded from the normalisation, as before.
    norm = np.nansum(pdf) * dx
    return pdf / norm

def failure_cdf_new(I, ramp_speed, critical_current, f_j0, Temp, U_0):
    """
    Cumulative distribution of the switching current (barrier-scaled model).

    Obtained by integrating ``failure_pdf_new`` over the grid ``I`` with the
    cumulative trapezoidal rule, starting from zero at ``I[0]``. A NaN in the
    density propagates to every later entry of the result.

    Parameters
    ----------
    I : array-like
        Current grid (presumably 1-D, ascending and evenly spaced).
    ramp_speed : float
        Ramp rate of the bias current, forwarded to ``failure_pdf_new``.
    critical_current : float
        Critical current of the junction.
    f_j0 : float
        Zero-temperature plasma frequency.
    Temp : float
        Temperature of the system.
    U_0 : float
        Scaling factor for the potential barrier.

    Returns
    -------
    numpy.ndarray
        Cumulative probability at each point of ``I``.
    """
    density = failure_pdf_new(I, ramp_speed, critical_current, f_j0, Temp, U_0)
    # cumulative_trapezoid replaces the cumtrapz alias removed in SciPy 1.14.
    return cumulative_trapezoid(density, I, initial=0)
    
    
class FailureDist_new(stats.rv_continuous):
    """
    SciPy continuous distribution of switching currents with a scaled barrier.

    Density and CDF evaluations are delegated to ``failure_pdf_new`` and
    ``failure_cdf_new``. Both work on the whole array of points passed in,
    so evaluate them on an ordered grid.
    """

    def __init__(self, ramp_speed, critical_current, f_j0, Temp, U_0, *args, **kwargs):
        # Lower support bound fixed at 0; SciPy's default upper bound (+inf) is kept.
        super().__init__(*args, a=0, **kwargs)
        self.ramp_speed = ramp_speed
        self.critical_current = critical_current
        self.f_j0 = f_j0
        self.Temp = Temp
        self.U_0 = U_0

    def _model_params(self):
        # Parameters in the positional order expected by the module-level helpers.
        return (self.ramp_speed, self.critical_current, self.f_j0, self.Temp, self.U_0)

    def _pdf(self, I):
        return failure_pdf_new(I, *self._model_params())

    def _cdf(self, I):
        return failure_cdf_new(I, *self._model_params())


    
def hazard2(I, critical_current, f_j0, Temp, U_0):
    """
    Switching rate using the cubic (near-critical) barrier approximation.

    Same Arrhenius form as ``hazard_new``. The barrier is replaced by its
    leading behaviour as the bias approaches the critical current,

        eta     = I / critical_current
        E_J     = hbar * critical_current / (2e)
        Delta_U = U_0 * (4*sqrt(2)/3) * E_J * (1 - eta)^(3/2)
        rate    = f_j0 * (1 - eta^2)^(1/4) * exp(-Delta_U / (k_B * Temp))

    Parameters
    ----------
    I : array-like
        Bias current points.
    critical_current : float
        Critical current of the junction.
    f_j0 : float
        Zero-temperature plasma frequency.
    Temp : float
        Temperature of the system.
    U_0 : float
        Scaling factor for the potential barrier; ``U_0 = 1`` gives the
        plain cubic approximation.

    Returns
    -------
    array-like
        Switching rate at each current in ``I``. Currents above
        ``critical_current`` lie outside the model's domain and produce NaN.
    """
    eta = I / critical_current

    freq_j = f_j0 * (1 - eta**2) ** (1 / 4)
    E_J = h_bar * critical_current / (2 * elementary_charge)
    # U_0 scales the barrier exactly as in hazard_new. It was previously
    # accepted but ignored, which made it a dead parameter for this model.
    Delta_U = U_0 * (4 * np.sqrt(2) / 3) * E_J * (1 - eta) ** (3 / 2)

    return freq_j * np.exp(-Delta_U / (Temp * k_B))

def failure_pdf2(I, ramp_speed, critical_current, f_j0, Temp, U_0):
    """
    Numerically normalised switching-current density for the cubic-barrier model.

    The unnormalised density is

        P(I) = hazard2(I)/ramp_speed * exp(-int_{I[0]}^{I} hazard2(I')/ramp_speed dI')

    with the integral taken by the cumulative trapezoidal rule on ``I``. The
    result is then divided by its Riemann sum ``nansum(P) * dx``, where
    ``dx = (max(I) - min(I)) / (len(I) - 1)``. That spacing is exact only for
    an evenly spaced grid.

    Parameters
    ----------
    I : array-like
        1-D current grid, assumed ascending and evenly spaced.
    ramp_speed : float
        Ramp rate of the bias current (dI/dt).
    critical_current : float
        Critical current of the junction.
    f_j0 : float
        Zero-temperature plasma frequency.
    Temp : float
        Temperature of the system.
    U_0 : float
        Scaling factor for the potential barrier, forwarded to ``hazard2``.

    Returns
    -------
    numpy.ndarray
        Normalised density at each point of ``I``. NaN entries (e.g. above
        the critical current) are kept as NaN.
    """
    I = np.asarray(I)
    # Use the cubic-barrier hazard this function is named for; it previously
    # called hazard_new, silently evaluating the full-barrier model instead.
    rate_per_current = hazard2(I, critical_current, f_j0, Temp, U_0) / ramp_speed
    survival = np.exp(-cumulative_trapezoid(rate_per_current, I, initial=0))
    pdf = rate_per_current * survival

    # N evenly spaced points spanning [min, max] are (max - min) / (N - 1) apart.
    dx = (I.max() - I.min()) / (I.size - 1)
    # NaN entries are excluded from the normalisation, as before.
    norm = np.nansum(pdf) * dx
    return pdf / norm

def failure_cdf2(I, ramp_speed, critical_current, f_j0, Temp, U_0):
    """
    Cumulative distribution of the switching current for the cubic-barrier model.

    Obtained by integrating ``failure_pdf2`` over the grid ``I`` with the
    cumulative trapezoidal rule, starting from zero at ``I[0]``. A NaN in the
    density propagates to every later entry of the result.

    Parameters
    ----------
    I : array-like
        Current grid (presumably 1-D, ascending and evenly spaced).
    ramp_speed : float
        Ramp rate of the bias current, forwarded to ``failure_pdf2``.
    critical_current : float
        Critical current of the junction.
    f_j0 : float
        Zero-temperature plasma frequency.
    Temp : float
        Temperature of the system.
    U_0 : float
        Scaling factor for the potential barrier.

    Returns
    -------
    numpy.ndarray
        Cumulative probability at each point of ``I``.
    """
    # Integrate the cubic-barrier density; this previously used
    # failure_pdf_new, i.e. the wrong model.
    density = failure_pdf2(I, ramp_speed, critical_current, f_j0, Temp, U_0)
    return cumulative_trapezoid(density, I, initial=0)
    
    
class FailureDist2(stats.rv_continuous):
    """
    SciPy continuous distribution of switching currents for the cubic-barrier model.

    Density and CDF evaluations are computed by ``failure_pdf2`` and
    ``failure_cdf2``, which are built on ``hazard2``. Both work on the whole
    array of points passed in, so evaluate them on an ordered grid.
    """

    def __init__(self, ramp_speed, critical_current, f_j0, Temp, U_0, *args, **kwargs):
        # set lower bound of support to 0 (upper bound is np.inf by default)
        super().__init__(a=0, *args, **kwargs)
        self.ramp_speed = ramp_speed
        self.critical_current = critical_current
        self.f_j0 = f_j0
        self.Temp = Temp
        self.U_0 = U_0

    def _pdf(self, I):
        # Delegate to the cubic-barrier family. This previously called
        # failure_pdf_new, which made FailureDist2 identical to FailureDist_new.
        return failure_pdf2(I, self.ramp_speed, self.critical_current, self.f_j0, self.Temp, self.U_0)

    def _cdf(self, I):
        return failure_cdf2(I, self.ramp_speed, self.critical_current, self.f_j0, self.Temp, self.U_0)